import logging
import math
import operator
import struct


# Digest sizes in bytes (256 // 8 == 32, 384 // 8 == 48).
SHA256_LEN_BYTES = 256 // 8
SHA384_LEN_BYTES = 384 // 8
# PSS salt length in bytes; corresponds to the openssl option
# rsa_pss_saltlen:32 used during signing.
SALT_LENGTH = 32
# Word size in bytes used by the int/bytes conversion helpers in this module.
INT_LEN = 4

# Module-level logger; the helpers below log conversion/validation failures here.
logger = logging.getLogger(__name__) 

# ASN.1 DER header prefixes for RSA public keys.
# Each *_HEADER_1 list spells out, byte by byte:
#   30 82 xx xx             outer SEQUENCE
#   30 0D 06 09 2A 86 48 86 F7 0D 01 01 01 05 00
#                           AlgorithmIdentifier: rsaEncryption OID, NULL params
#   03 82 xx xx 00          BIT STRING (0 unused bits)
#   30 82 xx xx             inner SEQUENCE
#   02 82 01 00 / 01 80     INTEGER tag/length for the modulus:
#                           0x0100 (256) bytes for RSA_256_*, 0x0180 (384) for RSA_384_*
# NOTE(review): str([...]) evaluates to the text "[48, 130, 1, ...]" (a list
# repr), not raw bytes. If callers expect DER bytes this would need
# bytes([...]). Confirm how these constants are consumed before changing them.
RSA_256_PUBLIC_KEY_HEADER_1 = str(
    [0x30, 0x82, 0x01, 0x22, 0x30, 0x0D, 0x06, 0x09,
     0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01,
     0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0F, 0x00,
     0x30, 0x82, 0x01, 0x0A, 0x02, 0x82, 0x01, 0x00]
)

RSA_384_PUBLIC_KEY_HEADER_1 = str(
    [0x30, 0x82, 0x01, 0xA2, 0x30, 0x0D, 0x06, 0x09,
     0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01,
     0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x8F, 0x00,
     0x30, 0x82, 0x01, 0x8A, 0x02, 0x82, 0x01, 0x80]
)

# DER INTEGER tag (0x02) with length 4. This presumably introduces a 4-byte
# public exponent after the modulus. The inner SEQUENCE lengths above
# (0x010A = 4 + 256 + 6, 0x018A = 4 + 384 + 6) are consistent with that.
# TODO confirm. Same str(...) caveat as above.
RSA_PUBLIC_KEY_HEADER_2 = str([0x02, 0x04])


def hex_string_to_int(hex_string):
    """Convert a hexadecimal string to an int.

    Arguments:
        hex_string {str} -- hex string, optionally prefixed with "0x" (e.g. "0xff", "ff")

    Returns:
        {int} -- int value of hex_string

    Raises:
        TypeError -- hex_string is not a str/bytes-like value
        ValueError -- hex_string is a string but not valid hexadecimal
    """
    try:
        return int(hex_string, 16)
    except TypeError as err:
        logger.exception('hex_string [%s] must be hex string', hex_string)
        raise TypeError('hex_string [%s] must be hex string' % (hex_string,)) from err
    except ValueError as err:
        # Malformed text (e.g. "zz") used to escape unlogged; treat it like the
        # TypeError path but keep ValueError so existing callers still catch it.
        logger.exception('hex_string [%s] must be hex string', hex_string)
        raise ValueError('hex_string [%s] must be hex string' % (hex_string,)) from err


def bytes_to_int(data, byteorder='little'):
    r"""Convert bytes to an unsigned int.

    Every byte of ``data`` contributes to the result, so inputs of any length
    (including ones longer than 4 bytes) decode fully. Empty input yields 0.

    Arguments:
        data {bytes} -- bytes item (e.g. b'\x00')

    Keyword Arguments:
        byteorder {str} -- byteorder: little/big (default: {'little'})

    Returns:
        {int} -- unsigned int value of data

    Raises:
        AssertionError -- byteorder is neither 'little' nor 'big'
        TypeError -- data is not bytes-like
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # AssertionError is kept so existing callers' except clauses still match.
    if byteorder not in ('little', 'big'):
        logger.error('byteorder [%s] must be little or big', byteorder)
        raise AssertionError('byteorder [%s] must be little or big' % (byteorder,))
    try:
        # int.from_bytes decodes the whole buffer. The earlier struct-based
        # decoding returned only the first 32-bit word and failed on lengths
        # that were not a multiple of 4 once they exceeded 4 bytes.
        return int.from_bytes(data, byteorder)
    except TypeError as err:
        logger.exception('data must be bytes')
        raise TypeError('data must be bytes') from err

def int_to_bytes(integer, byteorder='little'):
    """Convert a non-negative int to bytes padded to whole 4-byte words.

    The output length is the smallest multiple of INT_LEN bytes that holds
    ``integer``, with a minimum of one word. Values below 2**32 therefore
    always produce exactly INT_LEN bytes, and larger values produce 8, 12, ...

    Arguments:
        integer {int} -- int to convert (must be >= 0)

    Keyword Arguments:
        byteorder {str} -- byteorder: little/big (default: {'little'})

    Returns:
        {bytes} -- bytes value of integer

    Raises:
        AssertionError -- byteorder is neither 'little' nor 'big'
        TypeError -- integer is not an int (or int-like via __index__)
        OverflowError -- integer is negative
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if byteorder not in ('little', 'big'):
        logger.error('byteorder [%s] must be little or big', byteorder)
        raise AssertionError('byteorder [%s] must be little or big' % (byteorder,))
    try:
        # operator.index rejects floats/strings with TypeError. The previous
        # `.bit_length()` call raised AttributeError for those inputs instead.
        value = operator.index(integer)
    except TypeError as err:
        logger.exception('integer [%s] must be integer', integer)
        raise TypeError('integer [%s] must be integer' % (integer,)) from err
    byte_count = (value.bit_length() + 7) // 8
    # Round UP to whole words (at least one). The old floor division plus a
    # single-value struct.pack failed for every value >= 2**32.
    word_count = max(1, (byte_count + INT_LEN - 1) // INT_LEN)
    return value.to_bytes(word_count * INT_LEN, byteorder)
    


def extract_bits_from_int(integer, start, end):
    """Extract the bit field [start, end) from an integer.

    Bit 0 is the least significant bit and ``end`` is exclusive, so the
    result holds ``end - start`` bits, shifted down so that bit ``start``
    becomes bit 0.

    Arguments:
        integer {int} -- value to extract from
        start {int} -- index of the first bit to extract (inclusive, >= 0)
        end {int} -- index one past the last bit to extract (exclusive, <= 128)

    Returns:
        {int} -- extracted bits as int

    Raises:
        AssertionError -- start/end are not ints or violate 0 <= start < end <= 128
        TypeError -- integer does not support `>>` / `&` with ints
    """
    # Ranges are capped at 16 bytes (128 bits).
    valid_range = (
        isinstance(start, int) and isinstance(end, int)
        and 0 <= start < end <= 16 * 8
    )
    # Explicit check instead of `assert` (stripped under -O). Log BEFORE
    # raising; the original log call sat after `raise` and never ran.
    if not valid_range:
        logger.error('start [%s] and end [%s] range is not valid', start, end)
        raise AssertionError('start [%s] and end [%s] range is not valid' % (start, end))
    try:
        return (integer >> start) & ((1 << (end - start)) - 1)
    except TypeError as err:
        logger.exception('integer [%s] must be integer', integer)
        raise TypeError('integer [%s] must be integer' % (integer,)) from err

def extract_bits_from_byte(data, start, end):
    """Extract the bit field [start, end) from a bytes value.

    ``data`` is decoded with bytes_to_int (default little byteorder), the
    bits are taken with extract_bits_from_int (``end`` exclusive), and the
    result is re-encoded with int_to_bytes (default little byteorder).

    Arguments:
        data {bytes} -- value to extract from
        start {int} -- first bit to extract (inclusive)
        end {int} -- bit index one past the last bit to extract (exclusive)

    Returns:
        {bytes} -- extracted bits as bytes
    """
    return int_to_bytes(extract_bits_from_int(bytes_to_int(data), start, end))
